import sys

import torch

sys.path.append("src")
from pruner.utils import find_layers


def weight_importance(model, tokenizer, device=torch.device("cuda:0")):
    """
    Compute a magnitude-based importance score for each prunable weight.

    Walks every decoder layer of ``model`` and, for each module returned by
    ``find_layers(layer)``, stores the element-wise absolute value of that
    module's weight tensor.

    Args:
        model: Model whose layers live either at ``model.model.layers`` or,
            failing that, at ``model.model.decoder.layers``.
        tokenizer: Not used by this function. NOTE(review): presumably kept so
            the signature matches other importance scorers; confirm with callers.
        device: Not used by this function (see ``tokenizer`` note).

    Returns:
        dict[int, torch.Tensor]: Maps a running integer index (0, 1, 2, ...)
        to ``|W|`` for each weight. The index runs across all layers, in layer
        order and then in the iteration order of the dict returned by
        ``find_layers``.

    Raises:
        AttributeError: If ``model`` exposes neither layer layout.
    """
    try:
        layers = model.model.layers
    except AttributeError:
        # Only a missing attribute should trigger the fallback layout; other
        # errors (and KeyboardInterrupt/SystemExit) must propagate.
        layers = model.model.decoder.layers

    W_metrics = {}
    cnt = 0
    for layer in layers:
        subset = find_layers(layer)
        for name in subset:
            # .abs() allocates a new tensor, so the stored metric never aliases
            # the live parameter storage.
            W_metrics[cnt] = subset[name].weight.data.abs()
            cnt += 1
    return W_metrics
